diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index e8d0d02..91c663e 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     #   - should not be a replica due to tensor model parallelism
     grads = []
     grads_for_norm = []
+    grads_in_moe = []
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = not hasattr(param, 'shared') or not param.shared
@@ -63,7 +64,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             assert param.grad.type() == 'torch.cuda.FloatTensor'
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            if getattr(param, 'dp_comm', None) == 'none':
+                grads_in_moe.append(grad)  # normed separately below
+            else:
+                grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -72,6 +76,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
+        # TODO: moe -- grads_in_moe is not included in the inf-norm yet.
         total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -96,7 +101,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
 
+            if grads_in_moe:
+                grad_norm, _ = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm,
+                    dummy_overflow_buf,
+                    [grads_in_moe],
+                    False # no per-parameter norm
+                )
+                grad_norm = grad_norm ** norm_type
+                torch.distributed.all_reduce(grad_norm,
+                                             group=mpu.get_model_parallel_group())
+                total_norm += grad_norm
+
         else:
+            # TODO: moe -- grads_in_moe is not included in this norm yet.
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type